package com.zeyu.framework.core.distributed;

import com.google.common.collect.Sets;
import com.zeyu.framework.core.common.Constant;
import com.zeyu.framework.core.security.session.SessionDAO;
import com.zeyu.framework.core.web.servlet.Servlets;
import com.zeyu.framework.utils.DateUtils;
import com.zeyu.framework.utils.SerializeUtils;
import com.zeyu.framework.utils.StringUtils;
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.SimpleSession;
import org.apache.shiro.session.mgt.ValidatingSession;
import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.support.DefaultSubjectContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;

import javax.servlet.http.HttpServletRequest;
import java.io.Serializable;
import java.util.Collection;
import java.util.Date;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * 自定义授权会话管理类
 * 通过Redis管理Session,需要实现类似DAO接口的CRUD即可。
 * 1：最开始通过继承AbstractSessionDAO实现，发现doReadSession方法调用过于频繁，所以改为通过集成CachingSessionDAO来实现。
 * 注意，本地缓存通过EhCache实现，失效时间一定要远小于Redis失效时间，这样本地失效后，会访问Redis读取，
 * 并重新设置Redis上会话数据的过期时间
 * 2. 针对自定义的ShiroSession的Redis CRUD操作，通过isChanged标识符，确定是否需要调用Update方法
 * 3. 通过配置securityManager在属性cacheManager查找从缓存中查找Session是否存在，如果找不到才调用下面方法
 * 4. Shiro内部相应的组件（DefaultSecurityManager）会自动检测相应的对象（如Realm）是否实现了
 * CacheManagerAware并自动注入相应的CacheManager
 */
public class RedisSessionDAO extends CachingSessionDAO implements SessionDAO, Constant {

    // ================================================================
    // Constants
    // ================================================================

    /**
     * Logger
     */
    private Logger logger = LoggerFactory.getLogger(getClass());

    // ================================================================
    // Fields
    // ================================================================

    /**
     * 设置会话的过期时间
     */
    private int seconds = 1800;

    /**
     * 特殊配置 只用于没有Redis时 将Session放到EhCache中
     */
    private Boolean onlyEhCache;

    /**
     * session key
     * 保存到Redis中key的前缀 prefix+sessionId
     */
    private String sessionKeyPrefix = "shiro_session_";

    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    // ================================================================
    // Constructors
    // ================================================================

    // ================================================================
    // Methods from/for super Interfaces or SuperClass
    // ================================================================

    /**
     * 重写CachingSessionDAO中readSession方法，如果Session中没有登陆信息就调用doReadSession方法从Redis中重读
     */
    @Override
    public Session readSession(Serializable sessionId) throws UnknownSessionException {
        Session cached = null;
        try {
            cached = super.getCachedSession(sessionId);
        } catch (Exception e) {
            logger.error("get cache session error.", e);
        }
        if (onlyEhCache) {
            return cached;
        }
        // 如果缓存不存在或者缓存中没有登陆认证后记录的信息就重新从Redis中读取
        // session.getAttribute(DefaultSubjectContext.PRINCIPALS_SESSION_KEY) == null 代表没有登录，登录后Shiro会放入该值
        if (cached == null || cached.getAttribute(DefaultSubjectContext.PRINCIPALS_SESSION_KEY) == null) {
            try {
                cached = this.doReadSession(sessionId);
                if (cached == null) {
                    throw new UnknownSessionException();
                } else {
                    // 重置Redis中缓存过期时间并缓存起来 只有设置change才能更改最后一次访问时间
                    ((ShiroSession) cached).setChanged(true);
                    super.update(cached);
                }
            } catch (Exception e) {
                logger.warn("There is no session with id [" + sessionId + "]");
            }
        }
        return cached;
    }


    /**
     * 根据会话ID获取会话
     *
     * @param sessionId 会话ID
     * @return ShiroSession
     */
    @Override
    protected Session doReadSession(Serializable sessionId) {

        HttpServletRequest request = Servlets.getRequest();

        Map<String, Object> vals = getSession(request, sessionId);
        if (Boolean.valueOf(vals.get("static").toString())) {
            // 静态文件
            return null;
        }

        if (vals.get("session") != null && vals.get("session") instanceof Session) {
            return (Session) vals.get("session");
        }

        Session session = null;
        try {
            Object encddeSession = redisTemplate.opsForValue().get(sessionKeyPrefix + sessionId);
            if (encddeSession != null)
                session = SerializeUtils.deserializeFromString(encddeSession.toString());

            // 重置Redis中缓存过期时间
            if (session != null && session.getId() != null)
                redisTemplate.expire((sessionKeyPrefix + session.getId()), seconds, TimeUnit.SECONDS);

            logger.debug("doReadSession {} {}", sessionId, request != null ? request.getRequestURI() : "");
        } catch (Exception e) {
            logger.error("doReadSession {} {}", sessionId, request != null ? request.getRequestURI() : "", e);
        }

        if (request != null && session != null) {
            request.setAttribute("session_" + sessionId, session);
        }

        return session;
    }

    /**
     * 如DefaultSessionManager在创建完session后会调用该方法；
     * 如保存到关系数据库/文件系统/NoSQL数据库；即可以实现会话的持久化；
     * 返回会话ID；主要此处返回的ID.equals(session.getId())；
     */
    @Override
    protected Serializable doCreate(Session session) {
        HttpServletRequest request = Servlets.getRequest();
        if (request != null) {
            String uri = request.getServletPath();
            // 如果是静态文件，则不创建SESSION
            if (Servlets.isStaticFile(uri)) {
                return null;
            }
        }
        // 创建一个Id并设置给Session
        Serializable sessionId = this.generateSessionId(session);
        assignSessionId(session, sessionId);
        if (onlyEhCache) {
            return sessionId;
        }
        try {
            // session由Redis缓存失效决定，这里只是简单标识
            session.setTimeout(seconds);

            this.doUpdate(session);
            logger.info("sessionId {} name {} 被创建", sessionId, session.getClass().getName());
        } catch (Exception e) {
            logger.warn("创建Session失败", e);
        }
        return sessionId;
    }

    /**
     * 更新会话；如更新会话最后访问时间/停止会话/设置超时时间/设置移除属性等会调用
     */
    @Override
    protected void doUpdate(Session session) {
        // 错误判断
        if (session == null || session.getId() == null) {
            return;
        }
        //如果会话过期/停止 没必要再更新了
        try {
            if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
                return;
            }
        } catch (Exception e) {
            logger.error("ValidatingSession error");
        }
        if (onlyEhCache) {
            return;
        }

        HttpServletRequest request = Servlets.getRequest();
        if (!needUpdate(request)) {
            return;
        }

        try {
            if (session instanceof ShiroSession) {
                // 如果没有主要字段(除lastAccessTime以外其他字段)发生改变
                ShiroSession ss = (ShiroSession) session;
                if (!ss.isChanged()) {
                    return;
                }
                ss.setChanged(false);
                ss.setLastAccessTime(new Date());

                updateSession(session);
                logger.debug("update shiro session {} {}", session.getId(), request != null ? request.getRequestURI() : "");

                //发送广播
            } else if (session instanceof Serializable) {
                updateSession(session);
                logger.info("ID {} classname {} 作为非ShiroSession对象被更新, ", session.getId(), session.getClass().getName());
            } else {
                logger.debug("sessionId {} name {} 更新失败", session.getId(), session.getClass().getName());
            }
        } catch (Exception e) {
            logger.warn("更新Session失败", e);
        }
    }

    /**
     * 删除会话；当会话过期/会话停止（如用户退出时）会调用
     */
    @Override
    public void doDelete(Session session) {
        logger.debug("begin doDelete {} ", session);
        try {

            if (session == null || session.getId() == null) {
                return;
            }

            redisTemplate.opsForHash().delete(sessionKeyPrefix, session.getId());
            redisTemplate.delete(sessionKeyPrefix + session.getId());

            this.uncache(session.getId());
            logger.debug("shiro session id {} 被删除", session.getId());
        } catch (Exception e) {
            logger.warn("删除Session失败", e);
        }
    }

    /**
     * 删除cache中缓存的Session
     */
    public void uncache(Serializable sessionId) {
        try {
            Session session = super.getCachedSession(sessionId);
            super.uncache(session);
            logger.debug("删除本地 cache中缓存的Session id {} 的缓存失效", sessionId);
        } catch (Exception e) {
            logger.error("delete cache error.", e);
        }
    }

    /**
     * 获取当前所有活跃用户，如果用户量多此方法影响性能
     */
    @Override
    public Collection<Session> getActiveSessions() {
        return getActiveSessions(true);
    }

    /**
     * 获取活动会话
     *
     * @param includeLeave 是否包括离线（最后访问时间大于3分钟为离线会话）
     */
    @Override
    public Collection<Session> getActiveSessions(boolean includeLeave) {
        return getActiveSessions(includeLeave, null, null);
    }

    /**
     * 获取活动会话
     *
     * @param includeLeave  是否包括离线（最后访问时间大于3分钟为离线会话）
     * @param principal     根据登录者对象获取活动会话
     * @param filterSession 不为空，则过滤掉（不包含）这个会话。
     */
    @Override
    public Collection<Session> getActiveSessions(boolean includeLeave, Object principal, Session filterSession) {
        Set<Session> sessions = Sets.newHashSet();

        try {
            Map<Object, Object> map = redisTemplate.opsForHash().entries(sessionKeyPrefix);
            for (Map.Entry<Object, Object> e : map.entrySet()) {
                if (StringUtils.isNotBlank(e.getKey().toString()) && StringUtils.isNotBlank(e.getValue().toString())) {

                    String[] ss = StringUtils.split(e.getValue().toString(), "|");
                    if (ss != null && ss.length == 3) {// jedis.exists(sessionKeyPrefix + e.getKey())){
                        // Session session = (Session)JedisUtils.toObject(jedis.get(JedisUtils.getBytesKey(sessionKeyPrefix + e.getKey())));
                        SimpleSession session = new SimpleSession();
                        session.setId(e.getKey().toString());
                        session.setAttribute("principalId", ss[0]);
                        session.setTimeout(Long.valueOf(ss[1]));
                        session.setLastAccessTime(new Date(Long.valueOf(ss[2])));
                        try {
                            // 验证SESSION
                            session.validate();
                            if (logger.isDebugEnabled()) {
                                logger.debug("session {} is used", session);
                            }

                            boolean isActiveSession = false;
                            // 不包括离线并符合最后访问时间小于等于3分钟条件。
                            if (includeLeave || DateUtils.pastMinutes(session.getLastAccessTime()) <= 3) {
                                isActiveSession = true;
                            }

                            // 符合登陆者条件。
                            if (principal != null) {
                                PrincipalCollection pc = (PrincipalCollection) session.getAttribute(DefaultSubjectContext.PRINCIPALS_SESSION_KEY);
                                if (principal.toString().equals(pc != null ? pc.getPrimaryPrincipal().toString() : StringUtils.EMPTY)) {
                                    isActiveSession = true;
                                }
                            }
                            // 过滤掉的SESSION
                            if (filterSession != null && filterSession.getId().equals(session.getId())) {
                                isActiveSession = false;
                            }
                            if (isActiveSession) {
                                sessions.add(session);
                            }

                        }
                        // SESSION验证失败
                        catch (Exception ex) {
                            redisTemplate.opsForHash().delete(sessionKeyPrefix, e.getKey());
                        }
                    }
                    // 存储的SESSION不符合规则
                    else {
                        redisTemplate.opsForHash().delete(sessionKeyPrefix, e.getKey());
                    }
                }
                // 存储的SESSION无Value
                else if (StringUtils.isNotBlank(e.getKey().toString())) {
                    redisTemplate.opsForHash().delete(sessionKeyPrefix, e.getKey());
                }
            }
            logger.info("getActiveSessions size: {} ", sessions.size());
        } catch (Exception e) {
            logger.error("getActiveSessions in distributed", e);
        }
        return sessions;
    }

    /**
     * 返回本机Ehcache中Session
     */
    public Collection<Session> getEhCacheActiveSessions() {
        return super.getActiveSessions();
    }

    // ================================================================
    // Public or Protected Methods
    // ================================================================

    /**
     * 从Redis中读取，但不重置Redis中缓存过期时间
     */
    public Session doReadSessionWithoutExpire(Serializable sessionId) {
        if (onlyEhCache) {
            return readSession(sessionId);
        }

        Session session = null;
        try {
            session = SerializeUtils.deserializeFromString(redisTemplate.opsForValue().get(sessionKeyPrefix + sessionId));

            logger.debug("doReadSession no expire {} ", sessionId);
        } catch (Exception e) {
            logger.error("doReadSession no expire{} ", sessionId, e);
        }

        return session;
    }

    // ================================================================
    // Getter & Setter
    // ================================================================

    public void setSessionKeyPrefix(String sessionKeyPrefix) {
        this.sessionKeyPrefix = sessionKeyPrefix;
    }

    public void setSeconds(int seconds) {
        this.seconds = seconds;
    }

    public void setOnlyEhCache(Boolean onlyEhCache) {
        this.onlyEhCache = onlyEhCache;
    }

    // ================================================================
    // Private Methods
    // ================================================================

    private void updateSession(Session session) {
        // 获取登录者编号
        PrincipalCollection pc = (PrincipalCollection) session.getAttribute(DefaultSubjectContext.PRINCIPALS_SESSION_KEY);
        String principalId = pc != null ? pc.getPrimaryPrincipal().toString() : StringUtils.EMPTY;

        redisTemplate.opsForHash().put(sessionKeyPrefix, session.getId().toString(),
                principalId + "|" + session.getTimeout() + "|" + session.getLastAccessTime().getTime());
        redisTemplate.opsForValue().set(sessionKeyPrefix + session.getId(), SerializeUtils.serializeToString((Serializable) session));

        // 设置超期时间,统一由redis管理
        // int timeoutSeconds = (int) (session.getTimeout() / 1000);
        redisTemplate.expire((sessionKeyPrefix + session.getId()), seconds, TimeUnit.SECONDS);
    }

    // ================================================================
    // Inner or Anonymous Class
    // ================================================================

    // ================================================================
    // Test Methods
    // ================================================================
}
